/*
 * Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "tensorrt_llm/kernels/decodingCommon.h"
#include "tensorrt_llm/kernels/speculativeDecoding/common.h"
#include "tensorrt_llm/runtime/common.h"
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <curand_kernel.h>

namespace tensorrt_llm::kernels::speculative_decoding
{

//! \brief Sets pointers to logits in logitsPtrs according to the draftDecodingTokens.
//! \param logitsPtrs output buffer [batchSize]. Each element points to a [vocabSizePadded] buffer.
//! \param decodingTokens output buffer [batchSize], on GPU. Set to draftDecodingTokens + 1.
//! \param logits input buffer [numTokens, vocabSizePadded], on GPU. Contiguous logits in memory.
//! \param draftDecodingTokens input buffer [batchSize], on GPU. 0 for context requests, and actual draft len for gen
//! requests.
//! \param batchSize SizeType32. Batch size.
//! \param maxDecodingTokens SizeType32. Maximum number of decoding tokens per step per request.
//! \param vocabSizePadded SizeType32. Vocabulary size of the logits.
//! \param stream CUDA stream the work is enqueued on.
template <typename T>
void invokeAssembleTargetLogitsOffsets(T const** logitsPtrs, runtime::SizeType32* decodingTokens, T const* logits,
    runtime::SizeType32 const* draftDecodingTokens, runtime::SizeType32 batchSize,
    runtime::SizeType32 maxDecodingTokens, runtime::SizeType32 vocabSizePadded, cudaStream_t stream);

//! FIXME: We may get rid of this kernel in future optimization
//! \brief Set the logitsPtrs[numValidLogits][1, vocabSizePadded] from logits [numInputLogits * vocabSizePadded]
//! and outputIdsPtrs[numValidLogits][maxDecodingDraftTokens] from outputIds[numInputLogits * maxDecodingDraftTokens]
//! Can be merged into other kernels.
//! \param logitsPtrs [numValidLogits][1, vocabSizePadded], on GPU. The logits pointer array that will be used in topK
//! sampling.
//! \param logits [numInputLogits * vocabSizePadded], on GPU. Flattened logits, generated by the EagleNet.
//! \param outputIdsPtrs [numValidLogits][maxDecodingDraftTokens], on GPU. The output buffer of the topK sampling.
//! \param outputIds [numInputLogits * maxDecodingDraftTokens], on GPU. The flattened output buffer.
//! \param skipDecode [batchSize * maxNonLeavesPerLayer], on GPU. Flag whether to skip decoding or not.
//! First batchSize * sum(numValidLogitsPerRequest[:]) are set to true, the rest is false.
//! NOTE(review): the description above reads as if the valid entries are the skipped ones; confirm the polarity
//! against the kernel implementation.
//! \param numValidLogits [1], on GPU. Number of valid logits.
//! \param numInputLogits SizeType32. Number of logits from all the requests.
//! \param batchSize SizeType32. Batch size.
//! \param maxDecodingDraftTokens SizeType32. Maximum number of decoding draft tokens per step per request.
//! \param vocabSizePadded SizeType32. Vocabulary size of the logits.
//! \param stream CUDA stream the work is enqueued on.
template <typename T>
void invokeAssembleDraftLogitsOffsets(T const** logitsPtrs, T const* logits, runtime::TokenIdType** outputIdsPtrs,
    runtime::TokenIdType* outputIds, bool* skipDecode, runtime::SizeType32 const* numValidLogits,
    runtime::SizeType32 numInputLogits, runtime::SizeType32 batchSize, runtime::SizeType32 maxDecodingDraftTokens,
    runtime::SizeType32 vocabSizePadded, cudaStream_t stream);

//! \brief Prepares data for ctx stage EagleNet (EagleNet0).
//! EagleNet0 is always chunked context attn,
//! where we process either context tokens of the ctx requests or
//! newly accepted tokens from base model and append them to EagleNet KV cache.
//! For input/output examples visit test/model/eagle/test_prepare_drafter_inputs_plugin.py (ctx Eagle Net examples)
//! \param eagleNetSequenceLengths output buffer [batchSize]
//! Sequence length for the EagleNet0.
//! \param eagleNetContextLengths output buffer [batchSize]
//! Context lengths for the EagleNet0.
//! \param outputIds output buffer [numOutputTokens], flattened selected tokens ids without padding.
//! \param positionIds output buffer [numOutputTokens], flattened selected pos ids without padding
//! \param hiddenStatesIndices output buffer [numOutputTokens],
//! indices of the hidden states for selected tokens for the next EagleNet iteration.
//! E.g. With 3 requests where the first two are context requests with lengths 5 and 3 respectively and the 3rd
//! is gen request with draftDecodingTokens=8 and acceptedLength=3 and the best path is [0, 2, 5].
//! hiddenStatesIndices equals to [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 13].
//! (The 8 context tokens map to 0..7; the gen request's tokens start at offset 8, so path [0, 2, 5] maps to
//! 8, 10, 13.)
//! \param lastTokenIndices output buffer [batchSize * maxNonLeavesPerLayer],
//! Indices (starting with 1) of the logits of interest after the EagleNet prediction.
//! Used for index_select of the hidden_states in the end of the EagleNet. Padded to maxNonLeavesPerLayer with 1s.
//! \param numLastTokenIndices output buffer [1], number of logits predicted by the next EagleNet
//! iteration. For EagleNet0 the value is batchSize.
//! \param hiddenSizeBatchLevelStarts output buffer [batchSize * maxDraftPathLen + 1]
//! Exclusive sum of the hidden states produced per batch per level.
//! For EagleNet0 it is just cum sum of 1s for batchSize.
//! NOTE(review): maxDraftPathLen is not a parameter here; presumably it equals maxPathLen - 1 — confirm.
//! \param inputIds input buffer [numTokens], input ids (inputs of the Base model)
//! \param chunkedContextNextTokens input buffer [batchSize], first token from the next chunk in the chunked context
//! or -1 if current chunk is the last chunk (or not context phase).
//! \param baseNetSequenceLengths input buffer [batchSize] sequence lengths (inputs of the Base model).
//! \param baseNetContextLengths input buffer [batchSize], context lengths (inputs of the Base model).
//! \param acceptedTokens input buffer [batchSize, maxPathLen], ids of the accepted tokens.
//! \param acceptedLens input buffer [batchSize], on GPU. Number of accepted tokens.
//! \param prevDraftLens input buffer [batchSize], on GPU. Number of draft tokens (inputs of the Base model).
//! 0 for ctx requests and actual draft len for gen requests.
//! \param prevPaths input buffer [batchSize, maxDecodingTokens, maxPathLen], on GPU.
//! Previous paths (inputs of the Base model).
//! \param bestPathIds input buffer [batchSize], on GPU. Indices of the accepted path in prevPaths
//! \param batchSize batch size
//! \param maxPathLen Max number of accepted tokens per step
//! \param maxDecodingTokens Max number of draft tokens + 1
//! \param maxNonLeavesPerLayer Maximum number of non-leaf nodes per layer
//! \param stream CUDA stream the work is enqueued on.
void invokePrepareCtxEagleNetInputs(runtime::SizeType32* eagleNetSequenceLengths,
    runtime::SizeType32* eagleNetContextLengths, runtime::TokenIdType* outputIds, runtime::SizeType32* positionIds,
    runtime::SizeType32* hiddenStatesIndices, runtime::SizeType32* lastTokenIndices,
    runtime::SizeType32* numLastTokenIndices, runtime::SizeType32* hiddenSizeBatchLevelStarts,
    runtime::TokenIdType const* inputIds, runtime::TokenIdType const* chunkedContextNextTokens,
    runtime::SizeType32 const* baseNetSequenceLengths, runtime::SizeType32 const* baseNetContextLengths,
    runtime::TokenIdType const* acceptedTokens, runtime::SizeType32 const* acceptedLens,
    runtime::SizeType32 const* prevDraftLens, runtime::SizeType32 const* prevPaths,
    runtime::SizeType32 const* bestPathIds, runtime::SizeType32 batchSize, runtime::SizeType32 maxPathLen,
    runtime::SizeType32 maxDecodingTokens, runtime::SizeType32 maxNonLeavesPerLayer, cudaStream_t stream);

//! \brief Parameters of invokePrepareGenEagleNetInputs.
//! Groups the output buffers, workspace buffers, input buffers and sizes used to prepare the inputs of a gen
//! stage EagleNet iteration. checkParams() requires every pointer to be non-null.
struct PrepareGenEagleNetInputsParams
{
    //! output buffer [batchSize]
    //! Sequence length for the next EagleNet iteration.
    //! Equals to EagleNet0 seqLen + specDecodingGenLengths
    runtime::SizeType32* nextSequenceLengths{nullptr};
    //! output buffer [batchSize]
    //! Context length for the next EagleNet iteration.
    //! Equals to prevContextLengths
    runtime::SizeType32* nextContextLengths{nullptr};
    //! output buffer [numOutputTokens]
    //! Selected tokens ids.
    runtime::TokenIdType* outputIds{nullptr};
    //! output buffer [batchSize]
    //! Position ids of the selected tokens.
    runtime::SizeType32* positionIds{nullptr};
    //! output buffer [batchSize]
    //! Number of the draft tokens per request.
    runtime::SizeType32* specDecodingGenLengths{nullptr};
    //! output buffer [batchSize, maxDecodingTokens]
    //! Positions offsets (relative depth in the tree) of the selected tokens.
    runtime::SizeType32* specDecodingPositionOffsets{nullptr};
    //! output buffer [batchSize, maxDecodingTokens, ceil(maxDecodingTokens / 32)]
    //! uint32 packed mask of the draft tokens per request.
    runtime::SizeType32* specDecodingPackedMasks{nullptr};
    //! output buffer [numOutputTokens]
    //! Indices of the hidden states for selected tokens for the next EagleNet iteration.
    runtime::SizeType32* hiddenStatesIndices{nullptr};
    //! output buffer [batchSize * maxNonLeavesPerLayer]
    //! Indices of the hidden states where to sample logits from after the next EagleNet iteration.
    runtime::SizeType32* lastTokenIndices{nullptr};
    //! output buffer [1]
    //! Number of logits predicted by the next EagleNet iteration.
    runtime::SizeType32* numLastTokenIndices{nullptr};
    //! output buffer [(maxPathLen - 1) * batchSize + 1]
    //! Exclusive sum of the hidden states produced per batch per level.
    //! Same as inputHiddenSizeBatchStartsPerLevel, but also with data appended for cur level.
    runtime::SizeType32* outputHiddenSizeBatchStartsPerLevel{nullptr};

    // Workspace buffers
    //! [batchSize, maxDecodingTokens]
    //! Boolean mask to mark node as leaf or not.
    int8_t* isLeafMask{nullptr};
    //! [batchSize, maxDecodingDraftTokens]
    //! Indices of the draft tokens in the nextDraftIds selected at current level.
    runtime::SizeType32* selectedDraftIndices{nullptr};
    //! [batchSize, maxDecodingDraftTokens]
    //! Position offsets of the selected draft tokens.
    runtime::SizeType32* selectedDraftPosOffsets{nullptr};
    //! [batchSize]
    //! Number of selected tokens.
    runtime::SizeType32* numSelectedDraftIndices{nullptr};
    //! [batchSize, maxDecodingTokens, maxDecodingTokens]
    //! Boolean (not packed) mask of the selected draft tokens.
    bool* selectedMasks{nullptr};
    //! [batchSize + 1]
    runtime::SizeType32* cumSumGenerationLengths{nullptr};
    //! [1]
    runtime::SizeType32* maxGenerationLength{nullptr};
    //! [batchSize, maxDecodingTokens]
    runtime::SizeType32* nonLeavesInLevelOffsets{nullptr};
    //! [batchSize, maxDecodingTokens]
    runtime::SizeType32* parentNonLeafInLevelOffset{nullptr};

    //! input buffer [batchSize, maxDecodingDraftTokens]
    //! Drafted draft tokens. All tokens for the next Base model iteration are in the same buffer.
    runtime::TokenIdType const* nextDraftIds{nullptr};
    //! input buffer [batchSize]
    //! Sequence lengths after the ctx EagleNet0.
    runtime::SizeType32 const* eagleNet0SequenceLengths{nullptr};
    //! input buffer [batchSize]
    //! Context lengths after the ctx EagleNet0.
    runtime::SizeType32 const* prevContextLengths{nullptr};
    //! input buffer [batchSize, maxDecodingTokens, maxPathLen]
    //! Draft paths for the next iteration of the Base model. We use these paths to assemble output ids.
    runtime::SizeType32 const* nextPaths{nullptr};
    //! input buffer [(maxPathLen - 1) * batchSize + 1]
    //! Exclusive sum of the hidden states sizes per batch per layer.
    //! E.g. with BS=2, r0 and r1 have 1 hidden state at level 0 (golden token).
    //! r0 has 2 hidden states and r1 has 3 hidden states at level 1.
    //! Thus, hidden states are placed in memory as
    //! [h_0_0_0, h_0_0_1, h_0_1_0, h_1_1_0, h_0_1_1, h_1_1_1, h_2_1_1], where
    //! h_i_j_k means ith hidden state of request k at level j.
    //! The first entries of inputHiddenSizeBatchStartsPerLevel then equal [0, 1, 2, 4, 7]:
    //! the start of (r0, l0), (r1, l0), (r0, l1), (r1, l1), followed by the total number of hidden states.
    runtime::SizeType32 const* inputHiddenSizeBatchStartsPerLevel{nullptr};

    //! Tree level index. Same as gen iter of the EagleNet. For gen EagleNet it is >= 1 and < maxPathLen - 1
    runtime::SizeType32 levelIdx{0};
    //! Batch size
    runtime::SizeType32 batchSize{0};
    //! Max number of accepted tokens per step
    runtime::SizeType32 maxPathLen{0};
    //! Max number of draft tokens + 1
    runtime::SizeType32 maxDecodingTokens{0};
    //! Maximum number of non-leaf nodes per layer
    runtime::SizeType32 maxNonLeavesPerLayer{0};
    //! CUDA stream the work is enqueued on.
    //! Initialized to nullptr (the default stream) so that a default-initialized params object never carries an
    //! indeterminate stream handle.
    cudaStream_t stream{nullptr};

    //! \brief Validates the parameters: every buffer pointer must be non-null, sizes must be positive, and levelIdx
    //! must lie in (0, maxPathLen - 1).
    //! Declared const so it can be called on the const reference taken by invokePrepareGenEagleNetInputs.
    void checkParams() const
    {
        // Output buffers.
        TLLM_CHECK(nextSequenceLengths);
        TLLM_CHECK(nextContextLengths);
        TLLM_CHECK(outputIds);
        TLLM_CHECK(positionIds);
        TLLM_CHECK(specDecodingGenLengths);
        TLLM_CHECK(specDecodingPositionOffsets);
        TLLM_CHECK(specDecodingPackedMasks);
        TLLM_CHECK(hiddenStatesIndices);
        TLLM_CHECK(lastTokenIndices);
        TLLM_CHECK(numLastTokenIndices);
        TLLM_CHECK(outputHiddenSizeBatchStartsPerLevel);

        // Workspace buffers.
        TLLM_CHECK(isLeafMask);
        TLLM_CHECK(selectedDraftIndices);
        TLLM_CHECK(selectedDraftPosOffsets);
        TLLM_CHECK(numSelectedDraftIndices);
        TLLM_CHECK(selectedMasks);
        TLLM_CHECK(cumSumGenerationLengths);
        TLLM_CHECK(maxGenerationLength);
        TLLM_CHECK(nonLeavesInLevelOffsets);
        TLLM_CHECK(parentNonLeafInLevelOffset);

        // Input buffers.
        TLLM_CHECK(nextDraftIds);
        TLLM_CHECK(eagleNet0SequenceLengths);
        TLLM_CHECK(prevContextLengths);
        TLLM_CHECK(nextPaths);
        TLLM_CHECK(inputHiddenSizeBatchStartsPerLevel);

        // Sizes.
        TLLM_CHECK(batchSize > 0);
        TLLM_CHECK(maxPathLen > 0);
        TLLM_CHECK(maxDecodingTokens > 0);
        // Level 0 is handled by the ctx EagleNet0; the gen EagleNet only processes intermediate levels.
        TLLM_CHECK(0 < levelIdx && levelIdx < maxPathLen - 1);
        TLLM_CHECK(maxNonLeavesPerLayer > 0);
    }
};

//! \brief Prepares inputs for the gen stage EagleNet iteration (params.levelIdx > 0).
//! For input/output examples visit test/model/eagle/test_prepare_drafter_inputs_plugin.py (gen Eagle Net examples)
//! \param params all buffers, sizes and the CUDA stream; see PrepareGenEagleNetInputsParams.
void invokePrepareGenEagleNetInputs(PrepareGenEagleNetInputsParams const& params);

//! \brief Parameters of invokePackEagleGenerationLengths / invokePackEagle.
//! The input buffers are indexed by batch slot ([maxBatchSize, ...]). The output buffers are packed contiguously
//! ([batchSize, ...]).
struct PackEagleParams
{
    //! Number of requests; checkParams() requires batchSize == numContextRequests + numGenerationRequests.
    runtime::SizeType32 batchSize{0};
    runtime::SizeType32 maxNumPaths{0};
    runtime::SizeType32 maxDecodingTokens{0};
    runtime::SizeType32 maxPathLength{0};
    //! Number of context requests in the batch.
    runtime::SizeType32 numContextRequests{0};
    //! Number of generation requests in the batch. The outputSpecDecoding* buffers are required only when > 0.
    runtime::SizeType32 numGenerationRequests{0};

    //! inputs
    //! [batchSize]
    runtime::SizeType32 const* batchSlots{nullptr};

    //! [maxBatchSize]
    float const* inputTemperatures{nullptr};
    //! [maxBatchSize]
    float const* inputRandomDataSample{nullptr};
    //! [maxBatchSize, maxDecodingTokens]
    float const* inputRandomDataValidation{nullptr};
    //! [maxBatchSize, maxDecodingDraftTokens]
    runtime::TokenIdType const* inputNextDraftTokens{nullptr};
    //! [maxBatchSize, maxDecodingTokens, maxPathLen]
    runtime::SizeType32 const* inputNextDraftPaths{nullptr};
    //! [maxBatchSize]
    runtime::SizeType32 const* inputSpecDecodingGenerationLengths{nullptr};
    //! [maxBatchSize]
    //! NOTE(review): position offsets are per token elsewhere in this header ([batchSize, maxDecodingTokens]);
    //! this is likely [maxBatchSize, maxDecodingTokens] — confirm.
    runtime::SizeType32 const* inputSpecDecodingPositionOffsets{nullptr};
    //! [maxBatchSize, maxDecodingTokens, ceil(maxDecodingTokens / 32)]
    int32_t const* inputSpecDecodingPackedMasks{nullptr};

    //! outputs
    //! [batchSize]
    float* outputTemperatures{nullptr};
    //! [batchSize]
    float* outputRandomDataSample{nullptr};
    //! [batchSize, maxDecodingTokens]
    float* outputRandomDataValidation{nullptr};
    //! [batchSize, maxDecodingDraftTokens]
    runtime::TokenIdType* outputNextDraftTokens{nullptr};
    //! [batchSize]
    runtime::SizeType32* outputNextDraftLens{nullptr};
    //! [batchSize, maxDecodingTokens, maxPathLen]
    runtime::SizeType32* outputNextDraftPaths{nullptr};
    //! [batchSize]
    runtime::SizeType32* outputSpecDecodingGenerationLengths{nullptr};
    //! [batchSize]
    //! NOTE(review): likely [batchSize, maxDecodingTokens], mirroring the input buffer — confirm.
    runtime::SizeType32* outputSpecDecodingPositionOffsets{nullptr};
    //! [maxBatchSize, maxDecodingTokens, ceil(maxDecodingTokens / 32)]
    //! NOTE(review): the other outputs are packed to batchSize; confirm whether the leading dim is batchSize.
    int32_t* outputSpecDecodingPackedMasks{nullptr};

    // workspace
    //! [1]
    runtime::SizeType32* maxGenerationLength{nullptr};
    //! [batchSize + 1]
    runtime::SizeType32* cumSumGenerationLengths{nullptr};

    //! \brief Validates the parameters.
    //! All input, output and workspace pointers must be non-null. The exception is the outputSpecDecoding* buffers,
    //! which are required only when there are generation requests. Sizes must be positive, request counts
    //! non-negative, and batchSize must equal numContextRequests + numGenerationRequests.
    void checkParams() const
    {
        TLLM_CHECK(batchSlots);

        TLLM_CHECK(inputTemperatures);
        TLLM_CHECK(inputRandomDataSample);
        TLLM_CHECK(inputRandomDataValidation);
        TLLM_CHECK(inputNextDraftTokens);
        TLLM_CHECK(inputNextDraftPaths);
        TLLM_CHECK(inputSpecDecodingGenerationLengths);
        TLLM_CHECK(inputSpecDecodingPositionOffsets);
        TLLM_CHECK(inputSpecDecodingPackedMasks);

        TLLM_CHECK(outputTemperatures);
        TLLM_CHECK(outputRandomDataSample);
        TLLM_CHECK(outputRandomDataValidation);
        TLLM_CHECK(outputNextDraftTokens);
        TLLM_CHECK(outputNextDraftLens);
        TLLM_CHECK(outputNextDraftPaths);

        // Request counts must be non-negative. Without this check, batchSize == ctx + gen could be satisfied by a
        // negative numContextRequests.
        TLLM_CHECK(numContextRequests >= 0);
        TLLM_CHECK(numGenerationRequests >= 0);
        // The spec-decoding outputs are only written for generation requests.
        TLLM_CHECK(numGenerationRequests == 0 || outputSpecDecodingGenerationLengths);
        TLLM_CHECK(numGenerationRequests == 0 || outputSpecDecodingPositionOffsets);
        TLLM_CHECK(numGenerationRequests == 0 || outputSpecDecodingPackedMasks);

        TLLM_CHECK(maxGenerationLength);
        TLLM_CHECK(cumSumGenerationLengths);

        TLLM_CHECK(batchSize > 0);
        TLLM_CHECK(batchSize == numContextRequests + numGenerationRequests);
        TLLM_CHECK(maxDecodingTokens > 0);
        TLLM_CHECK(maxPathLength > 0);
        TLLM_CHECK(maxNumPaths > 0);
    }
};

//! \brief Packs outputSpecDecodingGenerationLengths from batch slot positions to contiguous memory.
//! \param params input/output buffers and sizes; see PackEagleParams.
//! \param stream CUDA stream the work is enqueued on.
void invokePackEagleGenerationLengths(PackEagleParams const& params, cudaStream_t stream);
//! \brief Packs the rest of the output tensors from batch slot positions to contiguous memory.
//! \param params input/output buffers and sizes; see PackEagleParams.
//! \param stream CUDA stream the work is enqueued on.
void invokePackEagle(PackEagleParams const& params, cudaStream_t stream);

//! \brief Parameters of invokeUnpackEagleData.
//! The inputs are engine outputs laid out per batch index ([batchSize, ...]). The outputs are written at batch slot
//! positions ([maxBatchSize, ...]).
struct UnpackEagleDataParams
{
    // Inputs.
    //! [batchSize]
    runtime::SizeType32 const* batchSlots{nullptr};
    //! [maxBatchSize]
    curandState_t* inputCurandState{nullptr};
    //! [maxBatchSize]
    float const* inputTemperatures{nullptr};
    //! [batchSize, maxDecodingDraftTokens]
    runtime::TokenIdType const* inputNextDraftTokens{nullptr};
    //! [batchSize]
    runtime::SizeType32 const* inputNextDraftLens{nullptr};
    //! [batchSize, maxDecodingTokens, maxPathLen]
    runtime::SizeType32 const* inputNextDraftPaths{nullptr};
    //! [batchSize, maxDecodingDraftTokens]
    runtime::TokenIdType const* inputLastDraftTokens{nullptr};
    //! [batchSize]
    runtime::SizeType32 const* inputLastDraftLens{nullptr};
    //! [batchSize, maxPathLen]
    runtime::TokenIdType const* inputAcceptedTokens{nullptr};
    //! [batchSize]
    runtime::SizeType32 const* inputAcceptedLens{nullptr};

    // Outputs.
    //! [maxBatchSize, maxSeqLen]
    runtime::TokenIdType* outputIds{nullptr};
    //! [maxBatchSize]
    runtime::SizeType32* outputNumNewTokens{nullptr};
    //! [maxBatchSize]
    runtime::SizeType32* outputSequenceLengths{nullptr};
    //! [maxBatchSize, maxDecodingDraftTokens]
    runtime::TokenIdType* outputUnpackedNextDraftTokens{nullptr};
    //! [maxBatchSize, maxDecodingDraftTokens]
    runtime::TokenIdType* outputNextDraftTokens{nullptr};
    //! [maxBatchSize]
    runtime::SizeType32* outputNextDraftLengths{nullptr};
    //! [maxBatchSize, maxDecodingTokens, maxPathLen]
    runtime::SizeType32* outputNextDraftPaths{nullptr};
    //! [maxBatchSize]
    runtime::SizeType32* outputPrevDraftLengths{nullptr};
    //! [maxBatchSize]
    runtime::SizeType32* outputNextGenerationLength{nullptr};
    //! [maxBatchSize, maxDecodingTokens]
    runtime::SizeType32* outputPositionIds{nullptr};

    //! [maxBatchSize]
    float* outputRandDataSample{nullptr};
    //! [maxBatchSize, maxDecodingTokens]
    float* outputRandDataVerification{nullptr};
    //! [maxBatchSize]
    float* outputTemperatures{nullptr};

    //! [maxBatchSize]
    runtime::SizeType32* outputEagleNetCtxRequestTypes{nullptr};
    //! [maxBatchSize]
    runtime::SizeType32* outputEagleNetCtxContextLengths{nullptr};
    //! [maxBatchSize]
    runtime::SizeType32* outputEagleNetCtxPastKeyValueLengths{nullptr};
    //! [maxBatchSize]
    runtime::SizeType32* outputEagleNetGenRequestTypes{nullptr};
    //! [maxBatchSize]
    runtime::SizeType32* outputEagleNetGenContextLengths{nullptr};
    //! [maxBatchSize]
    runtime::SizeType32* outputEagleNetGenPastKeyValueLengths{nullptr};

    // Sizes.
    runtime::SizeType32 batchSize{0};
    runtime::SizeType32 maxDecodingTokens{0};
    runtime::SizeType32 maxPathLength{0};
    runtime::SizeType32 maxSeqLen{0};

    //! \brief Validates the parameters: every pointer must be non-null and every size positive.
    //! Declared const so it can be called through a const reference, e.g. the one taken by invokeUnpackEagleData.
    void checkParams() const
    {
        TLLM_CHECK(batchSlots);
        TLLM_CHECK(inputCurandState);
        TLLM_CHECK(inputTemperatures);
        TLLM_CHECK(inputNextDraftTokens);
        TLLM_CHECK(inputNextDraftLens);
        TLLM_CHECK(inputNextDraftPaths);
        TLLM_CHECK(inputLastDraftTokens);
        TLLM_CHECK(inputLastDraftLens);
        TLLM_CHECK(inputAcceptedTokens);
        TLLM_CHECK(inputAcceptedLens);

        TLLM_CHECK(outputIds);
        TLLM_CHECK(outputNumNewTokens);
        TLLM_CHECK(outputSequenceLengths);
        TLLM_CHECK(outputUnpackedNextDraftTokens);
        TLLM_CHECK(outputNextDraftTokens);
        TLLM_CHECK(outputNextDraftLengths);
        TLLM_CHECK(outputNextDraftPaths);
        TLLM_CHECK(outputPrevDraftLengths);
        TLLM_CHECK(outputNextGenerationLength);
        TLLM_CHECK(outputPositionIds);

        TLLM_CHECK(outputRandDataSample);
        TLLM_CHECK(outputRandDataVerification);
        TLLM_CHECK(outputTemperatures);

        TLLM_CHECK(outputEagleNetCtxRequestTypes);
        TLLM_CHECK(outputEagleNetCtxContextLengths);
        TLLM_CHECK(outputEagleNetCtxPastKeyValueLengths);
        TLLM_CHECK(outputEagleNetGenRequestTypes);
        TLLM_CHECK(outputEagleNetGenContextLengths);
        TLLM_CHECK(outputEagleNetGenPastKeyValueLengths);

        TLLM_CHECK(batchSize > 0);
        TLLM_CHECK(maxDecodingTokens > 0);
        TLLM_CHECK(maxPathLength > 0);
        TLLM_CHECK(maxSeqLen > 0);
    }
};

//! \brief Unpacks outputs of the engine from contiguous memory to batch slots.
//! \param params input/output buffers and sizes; see UnpackEagleDataParams.
//! \param stream CUDA stream the work is enqueued on.
void invokeUnpackEagleData(UnpackEagleDataParams const& params, cudaStream_t stream);

//! \brief Parameters of invokeFillContextEagleData.
struct FillContextEagleParams
{
    // Outputs.
    //! [maxBatchSize]
    float* outputRandDataSample{nullptr};
    //! [maxBatchSize]
    float* outputTemperatures{nullptr};

    // Inputs.
    //! [maxBatchSize]
    float const* inputTemperatures{nullptr};
    //! [maxBatchSize]
    curandState_t* inputCurandState{nullptr};
    //! [batchSize]
    runtime::SizeType32 const* batchSlots{nullptr};

    //! Number of requests to fill.
    runtime::SizeType32 batchSize{0};

    //! \brief Validates the parameters: every pointer must be non-null and batchSize positive.
    //! Declared const so it can be called through a const reference, e.g. the one taken by
    //! invokeFillContextEagleData.
    void checkParams() const
    {
        TLLM_CHECK(outputRandDataSample);
        TLLM_CHECK(outputTemperatures);

        TLLM_CHECK(inputTemperatures);
        TLLM_CHECK(inputCurandState);
        TLLM_CHECK(batchSlots);

        TLLM_CHECK(batchSize > 0);
    }
};

//! \brief Fills necessary buffers before the Eagle context phase.
//! \param params input/output buffers and sizes; see FillContextEagleParams.
//! \param stream CUDA stream the work is enqueued on.
void invokeFillContextEagleData(FillContextEagleParams const& params, cudaStream_t stream);

//! \brief Extracts the attention mask from paths and packs it into int32_t bit masks stored at the batch slots.
//! \param specDecodingPackedMasks output buffer of packed bit masks, indexed by batch slot.
//! NOTE(review): presumably [maxBatchSize, maxDecodingTokens, ceil(maxDecodingTokens / 32)] like the other packed
//! mask buffers in this header — confirm.
//! \param batchSlots input buffer [batchSize], batch slot of each request.
//! \param nextDraftPaths input buffer of draft paths; presumably [batchSize or maxBatchSize, maxDecodingTokens,
//! maxPathLen] — confirm the leading dimension.
//! \param batchSize number of requests.
//! \param maxDecodingTokens max number of draft tokens + 1.
//! \param maxPathLen max path length.
//! \param stream CUDA stream the work is enqueued on.
void invokeGetPackedMaskFromPath(int32_t* specDecodingPackedMasks, runtime::SizeType32 const* batchSlots,
    runtime::SizeType32 const* nextDraftPaths, runtime::SizeType32 batchSize, runtime::SizeType32 maxDecodingTokens,
    runtime::SizeType32 maxPathLen, cudaStream_t stream);

//! \brief Extract the TopKs from tree of specific level (layerId).
//! \param paths [batchSize, maxDecodingTokens, maxPathLen], on GPU. Indices of the draft sequences.
//! \param topKs [numInputLogits], on GPU. The topK value for each input logits.
//! \param topKOffset [batchSize], on GPU. The topK start offset for each request. Will be used to slice the output
//! draft tokens.
//! \param numSuccessorsForEachNode [batchSize][maxDecodingTokens], on GPU. Record the number of
//! successors of each node from the corresponding tree for each requests.
//! \param layerId SizeType32. The layerId of the eagle net. Will be used to traverse a specific level of
//! the tree.
//! \param batchSize SizeType32. Batch size.
//! \param maxDecodingTokens SizeType32. Maximum number of decoding tokens per step per request.
//! \param maxPathLen SizeType32. Maximum path len of the draft sequence.
//! \param stream CUDA stream the work is enqueued on.
void invokeExtractTopKsFromPath(runtime::SizeType32 const* paths, runtime::SizeType32* topKs,
    runtime::SizeType32* topKOffset, runtime::SizeType32* numSuccessorsForEachNode, runtime::SizeType32 layerId,
    runtime::SizeType32 batchSize, runtime::SizeType32 maxDecodingTokens, runtime::SizeType32 maxPathLen,
    cudaStream_t stream);

//! \brief Copy the output draft tokens from the input buffer (generated by previous EagleNets)
//! and the new draft tokens generated by this layer to the output buffer of this plugin.
//! Also updates the draft length and copies paths from the input to the output buffer.
//! \param tmpOutputIdsPtrs [numInputLogits][maxDecodingDraftTokens], on GPU. The temporary output buffer of the topK
//! sampling.
//! \param topKs [numInputLogits], on GPU. The topK value for each input logits.
//! \param topKOffset [batchSize], on GPU. The topK start offset for each request. Will be used to slice the output
//! draft tokens.
//! \param pluginInputDraftIdsPtrs [batchSize * maxDecodingDraftTokens], on GPU. The plugin's input buffer,
//! which contains draft tokens generated by previous EagleNets.
//! \param pluginInputDraftLens [batchSize], on GPU. The
//! plugin's input buffer, which contains the draft length from previous EagleNets.
//! \param numValidLogits [1], on GPU. The number of valid logits.
//! \param pluginOutputDraftIdsPtrs [batchSize * maxDecodingDraftTokens], on GPU. The plugin's output buffer,
//! which will contain all the draft tokens generated by this and previous EagleNets.
//! \param pluginOutputDraftLens [batchSize], on GPU. The plugin's output buffer,
//! which will contain the draft length for the draft tokens.
//! \param layerId SizeType32. The layerId of the EagleNet. Will
//! be used to traverse a specific level of the tree.
//! \param batchSize SizeType32. Batch size.
//! \param maxDecodingDraftTokens SizeType32. maximum number of decoding draft tokens per step per request.
//! \param inputPaths [batchSize, maxDecodingTokens, maxPathLen], on GPU. Input paths.
//! \param outputPaths [batchSize, maxDecodingTokens, maxPathLen], on GPU. Output paths.
//! \param maxPathLen SizeType32. Maximum path len of the draft sequence.
//! \param stream CUDA stream the work is enqueued on.
void invokeCopyOutputTokensIds(runtime::TokenIdType const* const* tmpOutputIdsPtrs, runtime::SizeType32 const* topKs,
    runtime::SizeType32 const* topKOffset, runtime::TokenIdType const* pluginInputDraftIdsPtrs,
    runtime::SizeType32 const* pluginInputDraftLens, runtime::SizeType32 const* numValidLogits,
    runtime::TokenIdType* pluginOutputDraftIdsPtrs, runtime::SizeType32* pluginOutputDraftLens,
    runtime::SizeType32 layerId, runtime::SizeType32 batchSize, runtime::SizeType32 maxDecodingDraftTokens,
    runtime::SizeType32 const* inputPaths, runtime::SizeType32* outputPaths, runtime::SizeType32 maxPathLen,
    cudaStream_t stream);

//! \brief Augment seq slots so that non-last chunks are set to -1 (if chunkedContextNextTokens != -1).
//!
//! \param augmentedSeqSlots output buffer [engineBatchSize]
//! \param chunkedContextNextTokens input buffer [engineBatchSize], indicator of the not last chunk of the ctx
//! requests. -1 for gen requests and last chunk, != -1 otherwise.
//! \param lastDraftLens input buffer [engineBatchSize], number of draft tokens input to the current iteration.
//! 0 for ctx requests and > 0 for gen requests.
//! \param seqSlots input buffer [engineBatchSize], address map from local index to global index [0, batchSize)
//! -> [0, maxBatchSize)
//! \param engineBatchSize number of sequences processed in the engine.
//! Includes chunked context reqs that are not in the last chunk.
//! \param batchSize the number of sequences to be decoded
//! \param stream CUDA stream the work is enqueued on.
void invokeAugmentBatchSlots(runtime::SizeType32* augmentedSeqSlots,
    runtime::SizeType32 const* chunkedContextNextTokens, runtime::SizeType32 const* lastDraftLens,
    runtime::SizeType32 const* seqSlots, runtime::SizeType32 engineBatchSize, runtime::SizeType32 batchSize,
    cudaStream_t stream);

//! \brief For Eagle-2, set topK tensor according to the max topK value for each request.
//! And fill the batchSlots for the softMax kernel.
//! NOTE(review): there is no batchSlots parameter in this signature; the sentence above may be stale — confirm
//! against the implementation.
//! NOTE: "Dyanmic" in the function name is a misspelling of "Dynamic"; it is kept as-is because callers reference
//! this exact name.
//! \param layerIdx SizeType32. The layerIdx of the EagleNet.
//! \param batchSize SizeType32. Batch size.
//! \param numInputLogits SizeType32. Number of logits from all the requests.
//! \param topKs [numInputLogits], on GPU. The topK value for each input logits.
//! \param topKOffset [batchSize], on GPU. The topK start offset for each request. Will be used to slice the output
//! draft tokens.
//! \param dynamicTreeMaxTopK SizeType32. The max topK value for all request.
//! \param numValidLogits [1], on GPU. Number of valid logits.
//! \param stream CUDA stream the work is enqueued on.
void invokeSetTopKsFromDyanmicTreeMaxTopK(runtime::SizeType32 layerIdx, runtime::SizeType32 batchSize,
    runtime::SizeType32 numInputLogits, runtime::SizeType32* topKs, runtime::SizeType32* topKOffset,
    runtime::SizeType32 dynamicTreeMaxTopK, runtime::SizeType32 const* numValidLogits, cudaStream_t stream);

//! \brief For Eagle-2, copy this layer's scores and draft tokenIds to pluginInputAllLayersScores and
//! pluginInputAllLayersDraftTokenIds. Will also update the pluginOutputAllLayersDraftTokenIdsPredecessor, which records
//! the predecessor (parent) node of each draft token. The predecessor index refers to the tree composed of the draft
//! tokens of all layers.
//!
//! \param layerIdx SizeType32. The layerIdx of the EagleNet.
//! \param mNumEagleLayers SizeType32. The number of the EagleNets.
//! \param maxDecodingDraftTokens SizeType32. Maximum number of decoding draft tokens per step per request.
//! \param batchSize SizeType32. Batch size.
//! \param dynamicTreeMaxTopK SizeType32. The number of child nodes each draft tokens expands to.
//! \param pluginInputCurrentExpandIndices [batchSize, maxDecodingDraftTokens], on GPU. The indices of the nodes that
//! expand in current layer. The index is related to the tree that composed of draft tokens of all layers.
//! \param pluginInputAllLayersScores [batchSize, mNumEagleLayers, maxDecodingDraftTokens x maxDecodingDraftTokens], on
//! GPU. Scores of draft tokens at all layers.
//! \param pluginInputAllLayersDraftTokenIds [batchSize, mNumEagleLayers,
//! maxDecodingDraftTokens x maxDecodingDraftTokens], on GPU. Draft token ids at all layers.
//! \param pluginInputAllLayersDraftTokenIdsPredecessor [batchSize, mNumEagleLayers, maxDecodingDraftTokens x
//! maxDecodingDraftTokens], on GPU. The predecessors of draft tokens at all layers.
//! \param pluginOutputAllLayersScores [batchSize, mNumEagleLayers, maxDecodingDraftTokens x maxDecodingDraftTokens], on
//! GPU. Scores of draft tokens at all layers.
//! \param pluginOutputAllLayersDraftTokenIds [batchSize, mNumEagleLayers,
//! maxDecodingDraftTokens x maxDecodingDraftTokens], on GPU. Draft token ids at all layers.
//! \param pluginOutputAllLayersDraftTokenIdsPredecessor [batchSize, mNumEagleLayers, maxDecodingDraftTokens x
//! maxDecodingDraftTokens], on GPU. The predecessors of draft tokens at all layers.
//! \param firstTopKOutputLogProbs
//! [numInputLogits, maxDecodingDraftTokens], on GPU. The output logprobs of the first topK sampling.
//! \param firstTopKOutputIdsPtrs [numInputLogits, maxDecodingDraftTokens], on GPU.
//! The output ids of the first topK sampling.
//! \param stream CUDA stream the work is enqueued on.
void invokeCopyScoresAndDraftTokenIds(runtime::SizeType32 layerIdx, runtime::SizeType32 mNumEagleLayers,
    runtime::SizeType32 maxDecodingDraftTokens, runtime::SizeType32 batchSize, runtime::SizeType32 dynamicTreeMaxTopK,
    runtime::TokenIdType const* pluginInputCurrentExpandIndices, float const* pluginInputAllLayersScores,
    runtime::TokenIdType const* pluginInputAllLayersDraftTokenIds,
    runtime::TokenIdType const* pluginInputAllLayersDraftTokenIdsPredecessor, float* pluginOutputAllLayersScores,
    runtime::TokenIdType* pluginOutputAllLayersDraftTokenIds,
    runtime::TokenIdType* pluginOutputAllLayersDraftTokenIdsPredecessor, float* firstTopKOutputLogProbs,
    runtime::TokenIdType* firstTopKOutputIdsPtrs, cudaStream_t stream);

//! \brief Update this layer's scores with the previous layer's scores.
//! \param batchSize SizeType32. Batch size.
//! \param dynamicTreeMaxTopK SizeType32. The number of child nodes each draft token expands to.
//! \param maxDecodingDraftTokens SizeType32. Maximum number of decoding draft tokens per step per request.
//! \param curLogProbs [batchSize * dynamicTreeMaxTopK, maxDecodingDraftTokens], on GPU. This layer's output logprobs,
//! which are the output of the first topK sampling. Declared non-const; presumably updated in place with the
//! combined scores (TODO confirm against the kernel definition).
//! \param prevLayerScores [batchSize, maxDecodingDraftTokens], on GPU. Previous layer's scores (read-only).
//! \param stream cuda stream.
void invokeUpdateScores(runtime::SizeType32 batchSize, runtime::SizeType32 dynamicTreeMaxTopK,
    runtime::SizeType32 maxDecodingDraftTokens, float* curLogProbs, float const* prevLayerScores, cudaStream_t stream);

//! \brief Prepare the inputs of the second topK sampling.
//! \param batchSize SizeType32. Batch size.
//! \param dynamicTreeMaxTopK SizeType32. The number of child nodes each draft token expands to.
//! \param maxDecodingDraftTokens SizeType32. Maximum number of decoding draft tokens per step per request.
//! \param firstTopKOutputLogProbs [numInputLogits, maxDecodingDraftTokens], on GPU. The output logprobs of the first
//! topK sampling.
//! \param secondTopKInputScoresPtrs [batchSize], on GPU. The input buffer of the second topK sampling,
//! which is actually the updated firstTopKOutputLogProbs. Each element is a pointer to a
//! [maxDecodingDraftTokens] buffer.
//! NOTE(review): invokeUpdateScores documents the first topK logprobs as
//! [batchSize * dynamicTreeMaxTopK, maxDecodingDraftTokens]. That would give each request
//! dynamicTreeMaxTopK * maxDecodingDraftTokens scores; confirm the per-element extent stated above.
//! \param secondTopKOutputIdsFlatten [batchSize, maxDecodingDraftTokens], on GPU. The
//! outputIds buffer of the second topK sampling.
//! \param secondTopKOutputIdsPtrs [batchSize], on GPU. The output ids
//! buffer of the second topK sampling. Each element is a pointer to a [maxDecodingDraftTokens] buffer.
//! NOTE(review): invokeExtractScoresAndRealDraftTokensIds describes each element of secondTopKOutputIdsPtrs as a
//! [dynamicTreeMaxTopK * maxDecodingDraftTokens] buffer; one of the two descriptions is likely wrong.
//! \param stream cuda stream.
void invokeAssembleSecondTopKSamplingInputs(runtime::SizeType32 batchSize, runtime::SizeType32 dynamicTreeMaxTopK,
    runtime::SizeType32 maxDecodingDraftTokens, float* firstTopKOutputLogProbs, float** secondTopKInputScoresPtrs,
    runtime::TokenIdType* secondTopKOutputIdsFlatten, runtime::TokenIdType** secondTopKOutputIdsPtrs,
    cudaStream_t stream);

//! \brief Update the paths according to this layer's selected draft tokens (all layers except the last one).
//! Each EagleNet layer has a corresponding tree/paths.
//! Also updates the next layer's expand indices according to the results of the second topK sampling.
//! \param layerIdx SizeType32. The layerIdx of the EagleNet.
//! \param batchSize SizeType32. Batch size.
//! \param dynamicTreeMaxTopK SizeType32. The number of child nodes each draft token expands to.
//! \param maxDecodingTokens SizeType32. Maximum number of decoding tokens per step per request.
//! \param maxPathLen SizeType32. Maximum path len of the draft sequence.
//! \param prevPaths [batchSize, maxDecodingTokens, maxPathLen], on GPU. Previous layer's paths (read-only).
//! \param newPaths [batchSize, maxDecodingTokens, maxPathLen], on GPU. This layer's output paths,
//! grown from prevPaths.
//! \param secondTopKOutputIdsPtrs [batchSize], on GPU. Array of pointers; logically
//! [batchSize, maxDecodingDraftTokens]. The outputIds of the second topK sampling.
//! \param pluginOutputNextExpandIndices [batchSize, maxDecodingDraftTokens], on GPU. The next layer's expand
//! indices, which are this layer's draft token indices selected by the second topK sampling. The index is relative to
//! the tree composed of the draft tokens of all layers.
//! \param stream cuda stream.
void invokeUpdatePath(runtime::SizeType32 layerIdx, runtime::SizeType32 batchSize,
    runtime::SizeType32 dynamicTreeMaxTopK, runtime::SizeType32 maxDecodingTokens, runtime::SizeType32 maxPathLen,
    runtime::SizeType32 const* prevPaths, runtime::SizeType32* newPaths, runtime::TokenIdType** secondTopKOutputIdsPtrs,
    runtime::TokenIdType* pluginOutputNextExpandIndices, cudaStream_t stream);

//! \brief Copy this layer's output draft tokens and scores.
//! \param layerIdx SizeType32. The layerIdx of the EagleNet.
//! \param batchSize SizeType32. Batch size.
//! \param dynamicTreeMaxTopK SizeType32. The number of child nodes each draft token expands to.
//! \param maxDecodingDraftTokens SizeType32. Maximum number of decoding draft tokens per step per request.
//! \param curDraftIds [batchSize][maxDecodingDraftTokens], on GPU. Array of pointers to this layer's topK draft
//! tokenIds. If layerIdx == 0, these are the first topK's outputIds directly;
//! otherwise they are the second topK's outputIds.
//! \param pluginInputDraftIds [batchSize, maxDecodingDraftTokens], on GPU. The plugin's input buffer, which saves the
//! draft tokenIds.
//! \param pluginInputDraftLens [batchSize], on GPU. The number of draft tokenIds.
//! \param pluginOutputDraftIds [batchSize, maxDecodingDraftTokens], on GPU. The plugin's output buffer,
//! which saves the draft tokenIds.
//! \param pluginOutputDraftLens [batchSize], on GPU. The number of draft tokenIds.
//! \param curLayerScores [batchSize, maxDecodingDraftTokens], on GPU. This layer's scores.
//! If layerIdx == 0, these are the first topK's logProbs directly;
//! otherwise they are the second topK's logProbs.
//! \param pluginOutputCurrentScores [batchSize, maxDecodingDraftTokens], on GPU.
//! This layer's output scores, which will be used in the next layer.
//! \param stream cuda stream.
void invokeUpdateDraftTokensAndLensAndCurScores(runtime::SizeType32 layerIdx, runtime::SizeType32 batchSize,
    runtime::SizeType32 dynamicTreeMaxTopK, runtime::SizeType32 maxDecodingDraftTokens,
    runtime::TokenIdType const* const* curDraftIds, runtime::TokenIdType const* pluginInputDraftIds,
    runtime::SizeType32 const* pluginInputDraftLens, runtime::TokenIdType* pluginOutputDraftIds,
    runtime::SizeType32* pluginOutputDraftLens, float const* curLayerScores, float* pluginOutputCurrentScores,
    cudaStream_t stream);

//! \brief Extract scores from secondTopKInputScoresPtrs and secondTopKOutputIdsPtrs into secondTopKOutputLogProbs.
//! Also extracts the real outputIds based on firstTopKOutputIds and secondTopKOutputIdsPtrs.
//! \param batchSize SizeType32. Batch size.
//! \param dynamicTreeMaxTopK SizeType32. The number of child nodes each draft token expands to.
//! \param maxDecodingDraftTokens SizeType32. Maximum number of decoding draft tokens per step per request.
//! \param secondTopKInputScoresPtrs [batchSize], on GPU. The input of the second topK sampling.
//! Each element points to a [maxDecodingDraftTokens] buffer.
//! \param secondTopKOutputIdsPtrs [batchSize], on GPU. The output ids of the second topK sampling.
//! Each element points to a [dynamicTreeMaxTopK * maxDecodingDraftTokens] buffer.
//! NOTE(review): invokeAssembleSecondTopKSamplingInputs documents the elements of these two pointer arrays with
//! different extents; confirm which sizes are correct.
//! \param firstTopKOutputIds [batchSize * dynamicTreeMaxTopK * maxDecodingDraftTokens], on GPU. The output ids of the
//! first topK sampling.
//! \param secondTopKOutputLogProbs [batchSize, maxDecodingDraftTokens], on GPU. The output logprobs
//! of the second topK sampling.
//! \param stream cuda stream.
void invokeExtractScoresAndRealDraftTokensIds(runtime::SizeType32 batchSize, runtime::SizeType32 dynamicTreeMaxTopK,
    runtime::SizeType32 maxDecodingDraftTokens, float const* const* secondTopKInputScoresPtrs,
    runtime::TokenIdType* const* secondTopKOutputIdsPtrs, runtime::TokenIdType* firstTopKOutputIds,
    float* secondTopKOutputLogProbs, cudaStream_t stream);

//! \brief Prepare the inputs of the third topK sampling.
//! NOTE: the function identifier is spelled "Thrid"; it is kept as-is because renaming it would require changing
//! its definition and all callers.
//! \param batchSize SizeType32. Batch size.
//! \param maxDecodingDraftTokens SizeType32. Maximum number of decoding draft tokens per step per request.
//! \param mNumEagleLayers SizeType32. The number of the EagleNets.
//! \param maxNodesOnFinalTree SizeType32. The maximum number of nodes on the final tree (excluding the root node).
//! \param thirdTopKs [batchSize]. The topK value for each request.
//! \param pluginOutputAllLayersScores [batchSize, mNumEagleLayers, maxDecodingDraftTokens x maxDecodingDraftTokens], on
//! GPU. Scores of draft tokens at all layers.
//! \param thirdTopKInputScoresPtrs [batchSize], on GPU. The input scores of the third topK sampling.
//! Each element points to a [mNumEagleLayers * maxDecodingDraftTokens * maxDecodingDraftTokens] buffer.
//! \param thirdTopKOutputIds [batchSize, maxDecodingDraftTokens], on GPU. The outputIds
//! of the third topK sampling. The outputIds are relative to the tree composed of the draft tokens of all layers.
//! \param thirdTopKOutputIdsPtrs [batchSize], on GPU. The output ids of the third topK sampling. Each element points to
//! a [maxDecodingDraftTokens] buffer.
//! \param stream cuda stream.
void invokeAssembleThridTopKSamplingInputs(runtime::SizeType32 batchSize, runtime::SizeType32 maxDecodingDraftTokens,
    runtime::SizeType32 mNumEagleLayers, runtime::SizeType32 maxNodesOnFinalTree, runtime::SizeType32* thirdTopKs,
    float* pluginOutputAllLayersScores, float** thirdTopKInputScoresPtrs, runtime::TokenIdType* thirdTopKOutputIds,
    runtime::TokenIdType** thirdTopKOutputIdsPtrs, cudaStream_t stream);

//! \brief Reconstruct the paths at the final layer.
//! \param batchSize SizeType32. Batch size.
//! \param dynamicTreeMaxTopK SizeType32. The number of child nodes each draft token expands to.
//! \param maxDecodingDraftTokens SizeType32. Maximum number of decoding draft tokens per step per request.
//! \param maxDecodingTokens SizeType32. Maximum number of decoding tokens per step per request.
//! \param maxPathLen SizeType32. Maximum path len of the draft sequence.
//! \param mNumEagleLayers SizeType32. The number of the EagleNets.
//! \param maxNodesOnFinalTree SizeType32. The maximum number of nodes on the final tree (excluding the root node).
//! \param thirdTopKOutputIdsPtrs [batchSize], on GPU. The output ids of the third topK sampling.
//! Each element points to a [maxDecodingDraftTokens] buffer.
//! \param pluginOutputAllLayersDraftTokenIdsPredecessor [batchSize, mNumEagleLayers, maxDecodingDraftTokens x
//! maxDecodingDraftTokens], on GPU. The predecessors of draft tokens at all layers.
//! \param newPaths [batchSize, maxDecodingTokens, maxPathLen], on GPU.
//! This layer's output paths, grown from the outputs of the third sampling.
//! \param stream cuda stream.
void invokeReconstructFinalPath(runtime::SizeType32 batchSize, runtime::SizeType32 dynamicTreeMaxTopK,
    runtime::SizeType32 maxDecodingDraftTokens, runtime::SizeType32 maxDecodingTokens, runtime::SizeType32 maxPathLen,
    runtime::SizeType32 mNumEagleLayers, runtime::SizeType32 maxNodesOnFinalTree,
    runtime::TokenIdType* const* thirdTopKOutputIdsPtrs,
    runtime::TokenIdType* pluginOutputAllLayersDraftTokenIdsPredecessor, runtime::SizeType32* newPaths,
    cudaStream_t stream);

//! \brief Copy the last layer's selected draft tokens into the plugin's output buffer.
//! \param batchSize SizeType32. Batch size.
//! \param maxDecodingDraftTokens SizeType32. Maximum number of decoding draft tokens per step per request.
//! \param mNumEagleLayers SizeType32. The number of the EagleNets.
//! \param maxNodesOnFinalTree SizeType32. The maximum number of nodes on the final tree (excluding the root node).
//! \param thirdTopKOutputIdsPtrs [batchSize], on GPU. The output ids of the third topK sampling.
//! Each element points to a [maxDecodingDraftTokens] buffer.
//! \param pluginOutputAllLayersDraftTokenIds [batchSize, mNumEagleLayers, maxDecodingDraftTokens x
//! maxDecodingDraftTokens], on GPU. Draft token ids at all layers.
//! \param pluginOutputDraftTokenIds [batchSize * maxDecodingDraftTokens], on GPU.
//! The plugin's output buffer, which saves the output draft tokenIds.
//! \param pluginOutputDraftLens [batchSize], on GPU. The number of the output draft tokens.
//! \param stream cuda stream.
void invokeCopyFinalDraftTokens(runtime::SizeType32 batchSize, runtime::SizeType32 maxDecodingDraftTokens,
    runtime::SizeType32 mNumEagleLayers, runtime::SizeType32 maxNodesOnFinalTree,
    runtime::TokenIdType const* const* thirdTopKOutputIdsPtrs, runtime::TokenIdType* pluginOutputAllLayersDraftTokenIds,
    runtime::TokenIdType* pluginOutputDraftTokenIds, runtime::SizeType32* pluginOutputDraftLens, cudaStream_t stream);

} // namespace tensorrt_llm::kernels::speculative_decoding
